"""Sample from posterior distribution."""
import numpy as np
import pandas as pd
import scipy.special

import datasets
import models
import pystan_cache

# Sampler settings shared by the functions below: iterations per chain
# (the first half is treated as warmup when counting draws), the RNG seed,
# and the thinning factor.
ITER = 20_000
SEED = 5
THIN = 4

# Short aliases for the log-odds transform and its inverse (logistic sigmoid).
logit = scipy.special.logit
inv_logit = scipy.special.expit


def posterior_raven_forster_1789_1799(seed=SEED):
    """Simulate 1789-1799 novel counts by author gender from a Dirichlet-multinomial posterior.

    An informative Dirichlet prior is built from the Garside/Schöwerling 1800
    proportions of (unknown, male, female) authored novels, scaled to ``n0``
    pseudo-observations. For each year, the prior is updated with that year's
    counts from the Raven/Forster random sample. Each draw then samples gender
    proportions from the Dirichlet posterior and splits the year's known title
    total among the three categories with a multinomial draw.

    The returned dictionary mimics the layout of ``fit.extract()``.

    Parameters
    ----------
    seed : int
        Seed passed to ``np.random.seed`` (the global NumPy RNG) before sampling.

    Returns
    -------
    dict
        Keys ``'y_sim'``, ``'y_unknown_sim'``, ``'y_men_sim'`` and
        ``'y_women_sim'``, each an array of shape ``(num_draws, num_years)``
        with one column per row of the Raven/Forster sample. ``'y_sim'`` is
        the per-year total, identical across draws.
    """
    np.random.seed(seed)

    # Informative prior: n0 pseudo-observations spread according to the
    # known proportions from 1800.
    garside_schöwerling_1800 = datasets.garside_schöwerling_title_counts_by_gender().loc[1800]
    prior_proportions = garside_schöwerling_1800[['novels_unknown', 'novels_male', 'novels_female']].values
    prior_proportions = prior_proportions / prior_proportions.sum()
    n0 = 8  # prior strength (number of pseudo-observations)
    alpha0 = n0 * prior_proportions  # prior Dirichlet concentration parameters

    # Match the number of draws produced for the Stan model:
    # 4 chains, first half of each chain discarded as warmup, then thinned.
    num_draws = 4 * (ITER // 2) // THIN
    assert num_draws == 10_000, num_draws

    # Load the datasets once instead of on every loop iteration.
    sample_by_gender = datasets.raven_forster_1789_1799_random_sample_by_gender()
    year_totals = datasets.raven_forster_1789_1799()
    # One column per sampled year (11 years, 1789-1799, in the current data).
    num_years = len(sample_by_gender)

    y_unknown_sim = np.zeros((num_draws, num_years))
    y_men_sim = np.zeros((num_draws, num_years))
    y_women_sim = np.zeros((num_draws, num_years))
    for i, (year, row) in enumerate(sample_by_gender.iterrows()):
        alpha_post = row[['novels_unknown', 'novels_men', 'novels_women']].values + alpha0
        year_total = int(year_totals.loc[year])
        # Keep the per-draw order (dirichlet, then multinomial) so results
        # stay reproducible for a given seed.
        draws = np.array([
            np.random.multinomial(year_total, np.random.dirichlet(alpha_post))
            for _ in range(num_draws)
        ])
        unknown, men, women = draws[:, 0], draws[:, 1], draws[:, 2]
        # Every multinomial draw must sum to the known total for the year.
        np.testing.assert_array_equal(unknown + men + women, year_total)
        y_unknown_sim[:, i] = unknown
        y_men_sim[:, i] = men
        y_women_sim[:, i] = women

    y_sim = y_unknown_sim + y_men_sim + y_women_sim  # the known totals, repeated for every draw
    np.testing.assert_array_equal(y_sim.std(axis=0), 0)
    return {
        'y_sim': y_sim,
        'y_unknown_sim': y_unknown_sim,
        'y_men_sim': y_men_sim,
        'y_women_sim': y_women_sim,
    }


def sampling(iter=ITER, thin=THIN, seed=SEED):
    """Sample from posterior distribution.

    Builds the model's data dictionary from the project datasets and samples
    the model via ``pystan_cache.sampling``. It then adds posterior-predictive
    count draws (``'y_sim'``, ``'y_unknown_sim'``, ``'y_men_sim'``,
    ``'y_women_sim'``) generated with ``models.neg_binomial_2_rng``.

    Parameters
    ----------
    iter : int
        Forwarded to ``pystan_cache.sampling`` (presumably total iterations per
        chain, including warmup).
    thin : int
        Thinning factor, forwarded to ``pystan_cache.sampling``.
    seed : int
        Forwarded to ``pystan_cache.sampling``. Also seeds the global NumPy RNG,
        which the negative-binomial draws are expected to use.

    Returns
    -------
    dict
        The dictionary returned by ``pystan_cache.sampling`` (mimicking
        ``fit.extract()``), with the four ``*_sim`` arrays added.
    """
    np.random.seed(seed)
    novels_1800_1836 = datasets.garside_schöwerling_title_counts().loc[1800:]
    novels_1800_1829_by_gender = datasets.garside_schöwerling_title_counts_by_gender()
    publishers_circular = datasets.publishers_circular()
    nstc = datasets.nineteenth_century_short_title_catalogue_loced()
    casey = datasets.casey_athenaeum_novels_by_gender()
    assert len(novels_1800_1836) == 37
    # all data is made immutable, so we can hash it and cache results
    data = dict(
        y=tuple(novels_1800_1836.tolist()),
        y_unknown=tuple(novels_1800_1829_by_gender['novels_unknown'].tolist()),
        y_men=tuple(novels_1800_1829_by_gender['novels_male'].tolist()),
        y_women=tuple(novels_1800_1829_by_gender['novels_female'].tolist()),
        pc=tuple(publishers_circular.tolist()),
        casey_unknown=tuple((casey['athenaeum_novels_reviewed'] - casey['athenaeum_novels_reviewed_men'] - casey['athenaeum_novels_reviewed_women']).tolist()),  # noqa
        casey_men=tuple(casey['athenaeum_novels_reviewed_men'].tolist()),
        casey_women=tuple(casey['athenaeum_novels_reviewed_women'].tolist()),
        nstc=tuple(nstc.tolist()),
        year=tuple(range(1, 120 + 1)),
    )

    fit_extract = pystan_cache.sampling(models.make_model(), data, iter=iter, thin=thin, seed=seed)

    # NOTE: This is a HACK to get the fit.extract() dictionary to include
    # samples from the desired Negative Binomial distribution. Generating
    # samples in the `generated quantities {}` block hits overflow errors.
    #
    # Size the arrays from the draws actually returned, not from the
    # module-level ITER/THIN defaults. Callers may pass other values (the
    # smoke test uses iter=2, thin=1), and the defaults would then give a
    # mismatched shape.
    num_draws = len(fit_extract['phi_y_inv'])
    num_years = len(data['year'])
    rate = np.exp(fit_extract['lambda'])
    phi = np.ones((num_draws, num_years)) / fit_extract['phi_y_inv'][:, np.newaxis]
    fit_extract['y_sim'] = models.neg_binomial_2_rng(rate, phi)

    assert fit_extract['lambda'].shape == fit_extract['log_odds_unknown_gender'].shape
    # Share of titles with unknown author gender, and the share by men among
    # titles whose author gender is known.
    p_unknown = inv_logit(fit_extract['log_odds_unknown_gender'])
    p_men_of_known = inv_logit(fit_extract['log_odds_men_of_known_gender'])

    # The draw order below (unknown, men, women) matches the original, so
    # seeded output is unchanged.
    mu_unknown = rate * p_unknown
    fit_extract['y_unknown_sim'] = models.neg_binomial_2_rng(mu_unknown, phi)

    mu_men = rate * p_men_of_known * (1 - p_unknown)
    fit_extract['y_men_sim'] = models.neg_binomial_2_rng(mu_men, phi)

    mu_women = rate * (1 - p_men_of_known) * (1 - p_unknown)
    fit_extract['y_women_sim'] = models.neg_binomial_2_rng(mu_women, phi)

    return fit_extract


if __name__ == '__main__':
    # Smoke test: a tiny run (iter=2, thin=1) that only checks the pipeline
    # executes end to end and reports which arrays were produced.
    print('quick test that things are working:')
    print(sampling(iter=2, thin=1, seed=1).keys())
